import os
import gzip
import math
from collections import defaultdict
import pybedtools
from numpy import *


def select_dominant_promoters(promoters, filename="promoters.FANTOM_CAT.THP-1.bed"):
    """Select, for each gene, the promoter with the highest score.

    The BED file `filename` is read with pybedtools; the score column of
    each record is used as the promoter's count, except that records on
    'chrM' get a count of 0 and can therefore never be selected.

    Parameters:
        promoters: dict mapping a gene ID to a list of promoter names of the
            form "<chromosome>_<start>_<end>_<strand>".  Every name must be
            present in the BED file (a missing name raises KeyError).
            NOTE: this dict is modified in place: each list of names is
            replaced by a list of pybedtools intervals built from the names
            (chromosome, start, end, name, count, strand).
        filename: path of the promoter BED file; defaults to the file the
            script was originally hard-wired to.

    Returns:
        dict mapping gene ID to the BED record of its dominant promoter.
        Genes whose best count is not strictly positive are omitted.
    """
    tsss = {}
    counts = {}
    print("Reading", filename)
    for line in pybedtools.BedTool(filename):
        name = line.name
        # Mitochondrial promoters are recorded with a zero count.
        counts[name] = 0 if line.chrom == 'chrM' else float(line.score)
        tsss[name] = line
    dominant_promoters = {}
    for gene in promoters:
        names = promoters[gene]
        # Explicit loop (rather than max()) keeps the first promoter on ties
        # and does not depend on which `max` the star import leaves in scope.
        maximum = -1
        dominant_name = None
        for name in names:
            count = counts[name]
            if count > maximum:
                maximum = count
                dominant_name = name
        if maximum > 0:
            dominant_promoters[gene] = tsss[dominant_name]
        intervals = []
        for name in names:
            # Split from the right so that chromosome names that themselves
            # contain underscores do not break the unpacking.
            chromosome, start, end, strand = name.rsplit("_", 3)
            fields = [chromosome, int(start), int(end), name,
                      str(counts[name]), strand]
            intervals.append(pybedtools.create_interval_from_list(fields))
        promoters[gene] = intervals
    return dominant_promoters

def generate_genes(dominant_promoters, promoters, ends, distances):
    """Yield one GFF-style pybedtools interval per gene.

    Parameters:
        dominant_promoters: dict gene ID -> dominant promoter record; its
            `start` is used as the TSS.
        promoters: dict gene ID -> list of promoter intervals for the gene
            (objects with chrom, strand, start and end attributes).
        ends: dict gene ID -> position of the gene end.
        distances: dict gene ID -> distance value written to the attributes.

    Genes whose end lies on the wrong side of the TSS for their strand are
    skipped.  The gene extent runs from the TSS to the gene end and is then
    widened to cover every promoter of the gene (smallest promoter start on
    '+', largest promoter end on '-').

    Raises:
        Exception: if the dominant promoter's strand is neither '+' nor '-'.
        KeyError: if the gene is missing from `ends`, `distances` or
            `promoters`.
    """
    for gene in dominant_promoters:
        dominant = dominant_promoters[gene]
        chrom = dominant.chrom
        tss = dominant.start
        gene_end = ends[gene]
        strand = dominant.strand
        if strand == '+':
            if tss > gene_end:
                continue
            start = tss
            end = gene_end
        elif strand == '-':
            if gene_end > tss:
                continue
            # Both bounds start at the gene end; the upper bound is raised
            # by the promoter loop below.
            start = gene_end
            end = gene_end
        else:
            # Previously this fell through and reused `tss` from the
            # preceding gene (or failed with NameError on the first one).
            raise Exception("Unexpected strand %s" % strand)
        name = dominant.name
        distance = distances[gene]
        for promoter in promoters[gene]:
            assert promoter.chrom == chrom
            assert promoter.strand == strand
            # Plain comparisons instead of min()/max(): the file's
            # `from numpy import *` may rebind those names depending on the
            # NumPy version.
            if strand == "+":
                if promoter.start < start:
                    start = promoter.start
            else:
                if promoter.end > end:
                    end = promoter.end
        # NOTE(review): this increment is applied on both strands, so for
        # '-' genes it is added to the largest promoter end; confirm that
        # this is the intended end coordinate.
        end += 1
        # '%d' cannot format infinity, so write it out literally.
        if math.isinf(distance):
            distance_text = "inf"
        else:
            distance_text = "%d" % distance
        attributes = 'ID=%s;promoter=%s;TSS=%d;distance=%s;' % (gene, name, tss, distance_text)
        fields = [chrom, "FANTOMCAT", "gene", start+1, end, 0, strand, '.', attributes]
        interval = pybedtools.create_interval_from_list(fields)
        yield interval

def read_fantomcat_acceptable_genes(directory="/osc-fs_home/mdehoon/Data/Fantom6/FANTOMCAT",
                                    filename="F6_CAT.gene.info.tsv.gz"):
    """Return the set of gene IDs accepted from the FANTOM CAT gene table.

    The gzip-compressed, whitespace-separated table must start with the
    exact 41-column header listed below.  Every gene ID must start with
    "ENSG" or "CATG"; genes whose CAT_geneClass column is "small_RNA" are
    dropped, all others are returned.

    Parameters:
        directory: directory containing the table.
        filename: name of the table inside `directory`.

    Returns:
        set of gene ID strings.

    Raises:
        AssertionError: if the header differs from the expected one or a
            gene ID has an unexpected prefix.
    """
    expected_header = [
        "geneID", "geneName", "geneType", "CAT_geneClass", "cntg",
        "geneStart", "geneEnd", "strnd", "CAT_DHS_type", "gapmerID",
        "targetID", "numPrmtr", "numTrnscpt", "prmtrID", "trnscptID",
        "HGNC_ID", "HGNC_name", "HGNC_symbol", "HGNC_locus_group",
        "HGNC_locus_type", "HGNC_gene_family", "alias_name", "alias_symbol",
        "prev_name", "prev_symbol", "entrez_ID", "refseq_ID", "VEGA_ID",
        "UCSC_ID", "CCDS_ID", "uniprot_ID", "homeodb_ID", "cosmic_ID",
        "lncrnadb_ID", "mirbase_ID", "snornabase_ID", "orphanet_ID",
        "pseudogene_org_ID", "avg_exonNum", "max_exonNum", "note",
    ]
    path = os.path.join(directory, filename)
    print("Reading", path)
    geneIDs = set()
    # The context manager closes the file even if a check below fails.
    with gzip.open(path, 'rt') as handle:
        words = next(handle).split()
        assert words == expected_header, "Unexpected header in %s" % path
        for line in handle:
            words = line.split()
            geneID = words[0]
            if not geneID.startswith("ENSG"):
                assert geneID.startswith("CATG")
                # NOTE: a filter that skipped CATG genes without any "ENCT"
                # transcript (trnscptID column) was disabled in the original
                # code, so all CATG genes are accepted here.
            geneClass = words[3]
            if geneClass == "small_RNA":
                continue
            geneIDs.add(geneID)
    return geneIDs

def read_ensembl_gene_promoters(geneIDs,
                                directory="/osc-fs_home/mdehoon/Data/Fantom6/FANTOMCAT",
                                filename="F6_CAT.promoter.1_to_1_ID_mapping.tsv"):
    """Collect the promoter IDs listed for each requested gene.

    The whitespace-separated mapping file must start with the 8-column
    header checked below, and every data row must have 8 columns.

    Parameters:
        geneIDs: iterable of gene IDs to collect promoters for.
        directory: directory containing the mapping file.
        filename: name of the mapping file inside `directory`.

    Returns:
        dict mapping each gene ID (inserted in sorted order) to the list of
        promoter IDs (column prmtrID) of the rows whose geneID column
        matches, in file order.  Genes without any row map to an empty list.

    Raises:
        AssertionError: if the header or a row has an unexpected layout.
    """
    expected_header = ["prmtrID", "prmtrName", "trnscptID", "trnscptName",
                       "trnscptType", "geneID", "geneName", "geneType"]
    path = os.path.join(directory, filename)
    print("Reading", path)
    promoters = {geneID: [] for geneID in sorted(geneIDs)}
    # The context manager closes the file even if a check below fails.
    with open(path) as handle:
        words = next(handle).split()
        assert words == expected_header, "Unexpected header in %s" % path
        for line in handle:
            words = line.split()
            assert len(words) == 8
            geneID = words[5]
            # Test against the dict: same keys as geneIDs, and a constant-time
            # lookup whatever container the caller passed in.
            if geneID not in promoters:
                continue
            promoters[geneID].append(words[0])
    return promoters

def read_gene_ends(geneIDs,
                   directory="/osc-fs_home/mdehoon/Data/Fantom6/FANTOMCAT",
                   filename="F6_CAT.gene.bed.gz"):
    """Read the end position of each requested gene from a gzipped BED file.

    Parameters:
        geneIDs: container of gene IDs to keep; records whose name is not in
            it are ignored.
        directory: directory containing the BED file.
        filename: name of the BED file inside `directory`.

    Returns:
        dict mapping gene ID to its end position: `end - 1` of the record
        for '+' strand genes, `start` of the record for '-' strand genes.

    Raises:
        Exception: if a kept record has a strand other than '+' or '-'.
    """
    ends = {}
    path = os.path.join(directory, filename)
    print("Reading", path)
    # The context manager closes the file even when a record is rejected.
    with gzip.open(path, "rt") as handle:
        for line in pybedtools.BedTool(handle):
            geneID = line.name
            if geneID not in geneIDs:
                continue
            if line.strand == '+':
                end = line.end - 1
            elif line.strand == '-':
                end = line.start
            else:
                raise Exception("Unexpected strand")
            ends[geneID] = end
    return ends


def find_gene_distances(promoters):
    """Find, for each gene, the distance to the nearest opposite-strand TSS.

    Parameters:
        promoters: dict mapping gene ID to an object with `chrom`, `strand`
            ('+' or '-') and `start` attributes; `start` is taken as the
            TSS position.

    Returns:
        dict mapping gene ID to the smallest absolute difference between
        the gene's own position and the positions of all promoters on the
        same chromosome and the opposite strand.  If the chromosome has no
        promoter on the opposite strand, the distance is `math.inf`.

    Raises:
        KeyError or Exception: if a promoter's strand is neither '+' nor '-'.
    """
    gene_distances = {}
    tsss = defaultdict(lambda: {'+': [], '-': []})
    for gene in promoters:
        promoter = promoters[gene]
        tsss[promoter.chrom][promoter.strand].append(promoter.start)
    for chromosome in tsss:
        for strand in '+-':
            tsss[chromosome][strand] = array(tsss[chromosome][strand])
    for gene in promoters:
        promoter = promoters[gene]
        position = promoter.start
        if promoter.strand == '+':
            strand = '-'
        elif promoter.strand == '-':
            strand = '+'
        else:
            raise Exception("Unexpected strand %s" % promoter.strand)
        candidates = tsss[promoter.chrom][strand]
        if len(candidates) == 0:
            # Nothing on the opposite strand of this chromosome; taking the
            # minimum of an empty array would raise ValueError.
            gene_distances[gene] = math.inf
        else:
            # The array method is used so the result does not depend on
            # which `min` the file's star import leaves in scope.
            gene_distances[gene] = abs(candidates - position).min()
    return gene_distances

# Pipeline: accepted genes -> their promoters -> dominant promoter per gene
# -> distance to the nearest opposite-strand TSS -> gene ends -> GFF output.
geneIDs = read_fantomcat_acceptable_genes()
promoters = read_ensembl_gene_promoters(geneIDs)

# This call also replaces the promoter names stored in `promoters` by
# interval objects, which generate_genes relies on below.
dominant_promoters = select_dominant_promoters(promoters)
distances = find_gene_distances(dominant_promoters)
ends = read_gene_ends(geneIDs)

# Build the gene records lazily, then let pybedtools sort them.
genes = generate_genes(dominant_promoters, promoters, ends, distances)
genes = pybedtools.BedTool(genes).sort()

filename = "genes.FANTOM_CAT.THP-1.gff"
print("Saving gene information to {}".format(filename))
genes.saveas(filename)
